Abstract


Learning to predict multi-label outputs is challenging, but in many problems there 
is a natural metric on the outputs that can be used to improve predictions. In this 
paper we develop a loss function for multi-label learning, based on the Wasserstein 
distance. The Wasserstein distance provides a natural notion of dissimilarity for 
probability measures. Although optimizing with respect to the exact Wasserstein 
distance is costly, recent work has described a regularized approximation that is 
efficiently computed. We describe an efficient learning algorithm based on this 
regularization, as well as a novel extension of the Wasserstein distance from
probability measures to unnormalized measures. We also describe a statistical learning
bound for the loss. The Wasserstein loss can encourage smoothness of the
predictions with respect to a chosen metric on the output space. We demonstrate this
property on a real-data tag prediction problem, using the Yahoo Flickr Creative 
Commons dataset, outperforming a baseline that doesn’t use the metric. 


1 Introduction 

We consider the problem of learning to predict a non-negative measure over a finite set. This
problem includes many common machine learning scenarios. In multiclass classification, for example,
one often predicts a vector of scores or probabilities for the classes. And in semantic
segmentation [?], one can model the segmentation as being the support of a measure defined over the pixel
locations. Many problems in which the output of the learning machine is both non-negative and
multi-dimensional might be cast as predicting a measure.

We specifically focus on problems in which the output space has a natural metric or similarity
structure, which is known (or estimated) a priori. In practice, many learning problems have such
structure. In the ImageNet Large Scale Visual Recognition Challenge [ILSVRC] [?], for example, the
output dimensions correspond to 1000 object categories that have inherent semantic relationships,
some of which are captured in the WordNet hierarchy that accompanies the categories. Similarly, in
the keyword spotting task from the IARPA Babel speech recognition project, the outputs correspond
to keywords that likewise have semantic relationships. In what follows, we will call the similarity
structure on the label space the ground metric or semantic similarity.

Using the ground metric, we can measure prediction performance in a way that is sensitive to
relationships between the different output dimensions. For example, confusing dogs with cats might
be a more severe error than confusing breeds of dogs. A loss function that incorporates this metric
might encourage the learning algorithm to favor predictions that are, if not completely accurate, at
least semantically similar to the ground truth.

* Authors contributed equally.

1 Code and data are available at http://cbcl.mit.edu/wasserstein.

Figure 2: The Wasserstein loss encourages predictions that are similar to ground truth, robustly
to incorrect labeling of similar classes (see Appendix E.1). Shown is Euclidean distance between
prediction and ground truth vs. (left) number of classes, averaged over different noise levels and
(right) noise level, averaged over number of classes. Baseline is the multiclass logistic loss.


In this paper, we develop a loss function for multi-label learning that measures the Wasserstein
distance between a prediction and the target label, with respect to a chosen metric on the output
space. The Wasserstein distance is defined as the cost of the optimal transport plan for moving the
mass in the predicted measure to match that in the target, and has been applied to a wide range of
problems, including barycenter estimation [?], label propagation [?], and clustering [?]. To our
knowledge, this paper represents the first use of the Wasserstein distance as a loss for supervised
learning.

Figure 1: Semantically near-equivalent classes in ILSVRC (Siberian husky, Eskimo dog).



We briefly describe a case in which the Wasserstein loss improves learning performance. The setting
is a multiclass classification problem in which label noise arises from confusion of semantically
near-equivalent categories. Figure 1 shows such a case from the ILSVRC, in which the categories
Siberian husky and Eskimo dog are nearly indistinguishable. We synthesize a toy version of this
problem by identifying categories with points in the Euclidean plane and randomly switching the
training labels to nearby classes. The Wasserstein loss yields predictions that are closer to the ground
truth, robustly across all noise levels, as shown in Figure 2. The standard multiclass logistic loss is
the baseline for comparison. Section E.1 in the Appendix describes the experiment in more detail.


The main contributions of this paper are as follows. We formulate the problem of learning with prior 
knowledge of the ground metric, and propose the Wasserstein loss as an alternative to traditional 
information divergence-based loss functions. Specifically, we focus on empirical risk minimization 
(ERM) with the Wasserstein loss, and describe an efficient learning algorithm based on entropic 
regularization of the optimal transport problem. We also describe a novel extension to unnormalized 
measures that is similarly efficient to compute. We then justify ERM with the Wasserstein loss 
by showing a statistical learning bound. Finally, we evaluate the proposed loss on both synthetic 
examples and a real-world image annotation problem, demonstrating benefits for incorporating an 
output metric into the loss. 


2 Related work 

Decomposable loss functions like KL divergence and $\ell_p$ distances are very popular for
probabilistic [?] or vector-valued [6] predictions, as each component can be evaluated independently, often
leading to simple and efficient algorithms. The idea of exploiting smoothness in the label space
according to a prior metric has been explored in many different forms, including regularization [?]
and post-processing with graphical models [?]. Optimal transport provides a natural distance for
probability distributions over metric spaces. In [?], the optimal transport is used to formulate
the Wasserstein barycenter as a probability distribution with minimum total Wasserstein distance
to a set of given points on the probability simplex. [?] propagates histogram values on a graph by
minimizing a Dirichlet energy induced by optimal transport. The Wasserstein distance is also used
to formulate a metric for comparing clusters in [?], and is applied to image retrieval [10], contour
matching [11], and many other problems [12, 13]. However, to our knowledge, this is the first time
it is used as a loss function in a discriminative learning framework. The closest work to this
paper is a theoretical study [14] of an estimator that minimizes the optimal transport cost between the
empirical distribution and the estimated distribution in the setting of statistical parameter estimation.

3 Learning with a Wasserstein loss 

3.1 Problem setup and notation 

We consider the problem of learning a map from $\mathcal{X} \subset \mathbb{R}^D$ into the space $\mathcal{Y} = \mathbb{R}_+^K$ of measures over
a finite set $\mathcal{K}$ of size $|\mathcal{K}| = K$. Assume $\mathcal{K}$ possesses a metric $d_\mathcal{K}(\cdot, \cdot)$, which is called the ground
metric. $d_\mathcal{K}$ measures semantic similarity between dimensions of the output, which correspond to
the elements of $\mathcal{K}$. We perform learning over a hypothesis space $\mathcal{H}$ of predictors $h_\theta : \mathcal{X} \to \mathcal{Y}$,
parameterized by $\theta \in \Theta$. These might be linear logistic regression models, for example.

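In the standard statistical learning setting, we get an i.i.d. sequence of training examples
$S = ((x_1, y_1), \ldots, (x_N, y_N))$, sampled from an unknown joint distribution $\mathcal{P}_{\mathcal{X} \times \mathcal{Y}}$. Given a measure of
performance (a.k.a. risk) $\mathcal{E}(\cdot, \cdot)$, the goal is to find the predictor $h_\theta \in \mathcal{H}$ that minimizes the expected
risk $\mathbb{E}[\mathcal{E}(h_\theta(x), y)]$. Typically $\mathcal{E}(\cdot, \cdot)$ is difficult to optimize directly and the joint distribution $\mathcal{P}_{\mathcal{X} \times \mathcal{Y}}$
is unknown, so learning is performed via empirical risk minimization. Specifically, we solve

$$\min_{h_\theta \in \mathcal{H}} \left\{ \hat{\mathbb{E}}_S[\ell(h_\theta(x), y)] = \frac{1}{N} \sum_{i=1}^{N} \ell(h_\theta(x_i), y_i) \right\} \qquad (1)$$

with a loss function $\ell(\cdot, \cdot)$ acting as a surrogate of $\mathcal{E}(\cdot, \cdot)$.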

3.2 Optimal transport and the exact Wasserstein loss 

Information divergence-based loss functions are widely used in learning with probability-valued
outputs. Along with other popular measures like Hellinger distance and $\chi^2$ distance, these divergences
treat the output dimensions independently, ignoring any metric structure on $\mathcal{K}$.

Given a cost function $c : \mathcal{K} \times \mathcal{K} \to \mathbb{R}$, the optimal transport distance [?] measures the cheapest
way to transport the mass in probability measure $\mu_1$ to match that in $\mu_2$:

$$W_c(\mu_1, \mu_2) = \inf_{\gamma \in \Pi(\mu_1, \mu_2)} \int_{\mathcal{K} \times \mathcal{K}} c(\kappa_1, \kappa_2)\, \gamma(d\kappa_1, d\kappa_2) \qquad (2)$$

where $\Pi(\mu_1, \mu_2)$ is the set of joint probability measures on $\mathcal{K} \times \mathcal{K}$ having $\mu_1$ and $\mu_2$ as marginals. An
important case is that in which the cost is given by a metric $d_\mathcal{K}(\cdot, \cdot)$ or its $p$-th power $d_\mathcal{K}^p(\cdot, \cdot)$ with $p \geq 1$.
In this case, (2) is called a Wasserstein distance [?], also known as the earth mover's distance
[?]. In this paper, we only work with discrete measures. In the case of probability measures, these
are histograms in the simplex $\Delta^K$. When the ground truth $y$ and the output of $h$ both lie in the
simplex $\Delta^K$, we can define a Wasserstein loss.

Definition 3.1 (Exact Wasserstein Loss). For any $h_\theta \in \mathcal{H}$, $h_\theta : \mathcal{X} \to \Delta^K$, let $h_\theta(\kappa|x) = h_\theta(x)_\kappa$ be
the predicted value at element $\kappa \in \mathcal{K}$, given input $x \in \mathcal{X}$. Let $y(\kappa)$ be the ground truth value for $\kappa$
given by the corresponding label $y$. Then we define the exact Wasserstein loss as

$$W_p^p(h(\cdot|x), y(\cdot)) = \inf_{T \in \Pi(h(x), y)} \langle T, M \rangle \qquad (3)$$

where $M \in \mathbb{R}_+^{K \times K}$ is the distance matrix $M_{\kappa,\kappa'} = d_\mathcal{K}^p(\kappa, \kappa')$, and the set of valid transport plans is

$$\Pi(h(x), y) = \{T \in \mathbb{R}_+^{K \times K} : T\mathbf{1} = h(x),\ T^\top \mathbf{1} = y\} \qquad (4)$$

where $\mathbf{1}$ is the all-one vector.

$W_p^p$ is the cost of the optimal plan for transporting the predicted mass distribution $h(x)$ to match
the target distribution $y$. The penalty increases as more mass is transported over longer distances,
according to the ground metric $M$.
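As a rough illustration of Definition 3.1, the following Python sketch evaluates the cost $\langle T, M \rangle$ of two different valid transport plans for the same prediction and target. The three-class example, the plans `plan_a` and `plan_b`, and the helper names are hypothetical and chosen only for illustration; the exact loss (3) is the smallest such cost over all valid plans.

```python
def transport_cost(T, M):
    """Frobenius inner product <T, M>: total cost of moving mass according to T."""
    return sum(t * m for row_t, row_m in zip(T, M) for t, m in zip(row_t, row_m))


def is_valid_plan(T, h, y, tol=1e-12):
    """Check T >= 0, T 1 = h (row sums) and T^T 1 = y (column sums)."""
    n = len(h)
    nonneg = all(t >= 0.0 for row in T for t in row)
    rows_ok = all(abs(sum(T[i]) - h[i]) <= tol for i in range(n))
    cols_ok = all(abs(sum(T[i][j] for i in range(n)) - y[j]) <= tol for j in range(n))
    return nonneg and rows_ok and cols_ok


# Hypothetical example: three classes on a line, ground metric |i - j| (p = 1).
M = [[float(abs(i - j)) for j in range(3)] for i in range(3)]
h = [0.5, 0.5, 0.0]  # prediction
y = [0.5, 0.0, 0.5]  # target

# Plan A keeps the mass on class 0 and moves class 1 -> class 2.
plan_a = [[0.5, 0.0, 0.0],
          [0.0, 0.0, 0.5],
          [0.0, 0.0, 0.0]]
# Plan B moves class 0 -> class 2 and class 1 -> class 0.
plan_b = [[0.0, 0.0, 0.5],
          [0.5, 0.0, 0.0],
          [0.0, 0.0, 0.0]]

cost_a = transport_cost(plan_a, M)
cost_b = transport_cost(plan_b, M)
print(cost_a, cost_b)
```

Both plans satisfy the marginal constraints in (4), but plan B moves mass over longer distances and is therefore more expensive.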
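Algorithm 1 Gradient of the Wasserstein loss

Given $h(x)$, $y$, $\lambda$, $\mathbf{K}$. ($\gamma_a$, $\gamma_b$ if $h(x)$, $y$ unnormalized.)
$u \leftarrow \mathbf{1}$
while $u$ has not converged do
    if $h(x)$, $y$ normalized: $u \leftarrow h(x) \oslash \left(\mathbf{K}\left(y \oslash \mathbf{K}^\top u\right)\right)$
    if $h(x)$, $y$ unnormalized: $u \leftarrow h(x)^{\frac{\gamma_a\lambda}{\gamma_a\lambda+1}} \oslash \left(\mathbf{K}\left(y \oslash \mathbf{K}^\top u\right)^{\frac{\gamma_b\lambda}{\gamma_b\lambda+1}}\right)^{\frac{\gamma_a\lambda}{\gamma_a\lambda+1}}$
end while
If $h(x)$, $y$ unnormalized: $v \leftarrow y^{\frac{\gamma_b\lambda}{\gamma_b\lambda+1}} \oslash \left(\mathbf{K}^\top u\right)^{\frac{\gamma_b\lambda}{\gamma_b\lambda+1}}$
if $h(x)$, $y$ normalized: $\partial W_p^p / \partial h(x) \leftarrow \frac{\log u}{\lambda} - \frac{\log u^\top \mathbf{1}}{\lambda K}\mathbf{1}$
if $h(x)$, $y$ unnormalized: $\partial W_p^p / \partial h(x) \leftarrow \gamma_a\left(\mathbf{1} - (\mathrm{diag}(u)\mathbf{K}v) \oslash h(x)\right)$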

4 Efficient optimization via entropic regularization 

To do learning, we optimize the empirical risk minimization functional (1) by gradient descent.
Doing so requires evaluating a descent direction for the loss, with respect to the predictions $h(x)$.
Unfortunately, computing a subgradient of the exact Wasserstein loss (3) is quite costly, as follows.

The exact Wasserstein loss (3) is a linear program and a subgradient of its solution can be computed
using Lagrange duality. The dual LP of (3) is

$${}^{d}W_p^p(h(x), y) = \sup_{\alpha, \beta \in C_M} \alpha^\top h(x) + \beta^\top y, \qquad C_M = \{(\alpha, \beta) \in \mathbb{R}^{K \times K} : \alpha_\kappa + \beta_{\kappa'} \leq M_{\kappa,\kappa'}\}. \qquad (5)$$

As (3) is a linear program, at an optimum the values of the dual and the primal are equal (see, e.g.
[?]), hence the dual optimal $\alpha$ is a subgradient of the loss with respect to its first argument.

Computing $\alpha$ is costly, as it entails solving a linear program with $O(K^2)$ constraints, with $K$ being
the dimension of the output space. This cost can be prohibitive when optimizing by gradient descent.

4.1 Entropic regularization of optimal transport 

Cuturi [18] proposes a smoothed transport objective that enables efficient approximation of both the
transport matrix in (3) and the subgradient of the loss. [18] introduces an entropic regularization
term that results in a strictly convex problem:

$${}^{\lambda}W_p^p(h(\cdot|x), y(\cdot)) = \inf_{T \in \Pi(h(x), y)} \langle T, M \rangle - \frac{1}{\lambda} H(T), \qquad H(T) = -\sum_{\kappa,\kappa'} T_{\kappa,\kappa'} \log T_{\kappa,\kappa'}. \qquad (6)$$

Importantly, the transport matrix that solves (6) is a diagonal scaling of a matrix $\mathbf{K} = e^{-\lambda M - 1}$:

$$T^* = \mathrm{diag}(u)\,\mathbf{K}\,\mathrm{diag}(v) \qquad (7)$$

for $u = e^{\lambda\alpha}$ and $v = e^{\lambda\beta}$, where $\alpha$ and $\beta$ are the Lagrange dual variables for (6).

Identifying such a matrix subject to equality constraints on the row and column sums is exactly a
matrix balancing problem, which is well-studied in numerical linear algebra and for which efficient
iterative algorithms exist [?]. [?] and [18] use the well-known Sinkhorn-Knopp algorithm.
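The matrix balancing step can be sketched in a few lines of plain Python. This is a minimal illustration of Sinkhorn-Knopp style alternating scaling under stated assumptions, not the authors' implementation: the function name `sinkhorn_plan`, the fixed iteration count, the three-class example and the choice $\lambda = 2$ are all hypothetical, and $h$ and $y$ are assumed strictly positive so that the divisions are well defined.

```python
import math


def sinkhorn_plan(h, y, M, lam, n_iter=500):
    """Alternately rescale rows and columns of K = exp(-lam * M - 1)."""
    n = len(h)
    kernel = [[math.exp(-lam * M[i][j] - 1.0) for j in range(n)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * n
    for _ in range(n_iter):
        # Column step: v = y ./ (K^T u) makes the column sums of T equal y.
        v = [y[j] / sum(kernel[i][j] * u[i] for i in range(n)) for j in range(n)]
        # Row step: u = h ./ (K v) makes the row sums of T equal h.
        u = [h[i] / sum(kernel[i][j] * v[j] for j in range(n)) for i in range(n)]
    # T* = diag(u) K diag(v), as in (7).
    plan = [[u[i] * kernel[i][j] * v[j] for j in range(n)] for i in range(n)]
    return plan, u, v


M = [[float(abs(i - j)) for j in range(3)] for i in range(3)]
h = [0.5, 0.3, 0.2]
y = [0.2, 0.3, 0.5]
plan, u, v = sinkhorn_plan(h, y, M, lam=2.0)
row_sums = [sum(row) for row in plan]
col_sums = [sum(plan[i][j] for i in range(3)) for j in range(3)]
print(row_sums, col_sums)
```

After enough iterations the row sums of the plan match $h(x)$ and the column sums match $y$, which are the two constraints in (4).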

4.2 Extending smoothed transport to the learning setting 

When the output vectors $h(x)$ and $y$ lie in the simplex, (6) can be used directly in place of (3), as
(6) can approximate the exact Wasserstein distance closely for large enough $\lambda$ [18]. In this case, the
gradient $\alpha$ of the objective can be obtained from the optimal scaling vector $u$ as
$\alpha = \frac{\log u}{\lambda} - \frac{\log u^\top \mathbf{1}}{\lambda K}\mathbf{1}$.² A Sinkhorn iteration for the gradient is given in Algorithm 1.

2 Note that $\alpha$ is only defined up to a constant shift: any upscaling of the vector $u$ can be paired with a
corresponding downscaling of the vector $v$ (and vice versa) without altering the matrix $T^*$. The choice
$\alpha = \frac{\log u}{\lambda} - \frac{\log u^\top \mathbf{1}}{\lambda K}\mathbf{1}$ ensures that $\alpha$ is tangent to the simplex.
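The gradient computation for normalized inputs can be sketched as follows. This is an illustrative reading of the normalized branch of Algorithm 1, assuming strictly positive $h(x)$ and $y$ and a fixed number of iterations in place of a convergence test; the function name `wasserstein_grad` and the example values are hypothetical.

```python
import math


def wasserstein_grad(h, y, M, lam, n_iter=500):
    """Iterate u <- h ./ (K (y ./ K^T u)), then map log(u) to a gradient."""
    n = len(h)
    kernel = [[math.exp(-lam * M[i][j] - 1.0) for j in range(n)] for i in range(n)]
    u = [1.0] * n
    for _ in range(n_iter):
        kt_u = [sum(kernel[i][j] * u[i] for i in range(n)) for j in range(n)]
        v = [y[j] / kt_u[j] for j in range(n)]
        u = [h[i] / sum(kernel[i][j] * v[j] for j in range(n)) for i in range(n)]
    log_u = [math.log(ui) for ui in u]
    mean_log_u = sum(log_u) / n
    # alpha = log(u)/lam minus its mean, so that the entries of alpha sum to zero.
    return [(lu - mean_log_u) / lam for lu in log_u]


M = [[float(abs(i - j)) for j in range(3)] for i in range(3)]
h = [0.8, 0.1, 0.1]
y = [0.1, 0.1, 0.8]
grad = wasserstein_grad(h, y, M, lam=2.0)
print(grad)
```

Subtracting the mean of $\log u / \lambda$ is what makes the returned vector sum to zero, so a gradient step along it keeps the prediction's total mass unchanged.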
(a) Convergence to smoothed transport. (b) Approximation of exact Wasserstein. (c) Convergence of alternating projections ($\lambda = 50$).

Figure 3: The relaxed transport problem (8) for unnormalized measures.


For many learning problems, however, a normalized output assumption is unnatural. In image
segmentation, for example, the target shape is not naturally represented as a histogram. And even when
the prediction and the ground truth are constrained to the simplex, the observed label can be subject 
to noise that violates the constraint. 

There is more than one way to generalize optimal transport to unnormalized measures, and this is a 
subject of active study [20]. We will develop here a novel objective that deals effectively with the 
difference in total mass between h(x) and y while still being efficient to optimize. 



4.3 Relaxed transport 


We propose a novel relaxation that extends smoothed transport to unnormalized measures. By
replacing the equality constraints on the transport marginals in (6) with soft penalties with respect to
KL divergence, we get an unconstrained approximate transport problem. The resulting objective is:

$${}^{\lambda,\gamma_a,\gamma_b}W_{\mathrm{KL}}(h(\cdot|x), y(\cdot)) = \min_{T \in \mathbb{R}_+^{K \times K}} \langle T, M \rangle - \frac{1}{\lambda} H(T) + \gamma_a \widetilde{\mathrm{KL}}(T\mathbf{1} \,\|\, h(x)) + \gamma_b \widetilde{\mathrm{KL}}(T^\top\mathbf{1} \,\|\, y) \qquad (8)$$

where $\widetilde{\mathrm{KL}}(w \| z) = w^\top \log(w \oslash z) - \mathbf{1}^\top w + \mathbf{1}^\top z$ is the generalized KL divergence between
$w, z \in \mathbb{R}_+^K$. Here $\oslash$ represents element-wise division. As with the previous formulation, the optimal
transport matrix with respect to (8) is a diagonal scaling of the matrix $\mathbf{K}$.

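Proposition 4.1. The transport matrix $T^*$ optimizing (8) satisfies $T^* = \mathrm{diag}(u)\,\mathbf{K}\,\mathrm{diag}(v)$, where
$u = (h(x) \oslash T^*\mathbf{1})^{\gamma_a\lambda}$, $v = (y \oslash (T^*)^\top\mathbf{1})^{\gamma_b\lambda}$, and $\mathbf{K} = e^{-\lambda M - 1}$.

And the optimal transport matrix is a fixed point for a Sinkhorn-like iteration.

Proposition 4.2. $T^* = \mathrm{diag}(u)\,\mathbf{K}\,\mathrm{diag}(v)$ optimizing (8) satisfies:
i) $u = h(x)^{\frac{\gamma_a\lambda}{\gamma_a\lambda+1}} \odot (\mathbf{K}v)^{-\frac{\gamma_a\lambda}{\gamma_a\lambda+1}}$,
and ii) $v = y^{\frac{\gamma_b\lambda}{\gamma_b\lambda+1}} \odot (\mathbf{K}^\top u)^{-\frac{\gamma_b\lambda}{\gamma_b\lambda+1}}$, where $\odot$ represents element-wise multiplication.

Unlike the previous formulation, (8) is unconstrained with respect to $h(x)$. The gradient is given by
$\nabla_{h(x)} W_{\mathrm{KL}}(h(\cdot|x), y(\cdot)) = \gamma_a\left(\mathbf{1} - T^*\mathbf{1} \oslash h(x)\right)$. The iteration is given in Algorithm 1.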
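The Sinkhorn-like iteration of Proposition 4.2 and the gradient for unnormalized measures can be sketched as below. This is a minimal illustration rather than the authors' code: the function name `relaxed_transport`, the fixed iteration count, and the example values ($\lambda = 2$, $\gamma_a = \gamma_b = 1$, inputs with different total mass) are assumptions made for the example. Division with a positive exponent is used in place of multiplication with a negative exponent, which is the same update.

```python
import math


def relaxed_transport(h, y, M, lam, gamma_a, gamma_b, n_iter=500):
    """Fixed-point iteration for the relaxed transport problem (8)."""
    n = len(h)
    kernel = [[math.exp(-lam * M[i][j] - 1.0) for j in range(n)] for i in range(n)]
    ea = gamma_a * lam / (gamma_a * lam + 1.0)
    eb = gamma_b * lam / (gamma_b * lam + 1.0)
    u = [1.0] * n
    v = [1.0] * n
    for _ in range(n_iter):
        k_v = [sum(kernel[i][j] * v[j] for j in range(n)) for i in range(n)]
        u = [(h[i] / k_v[i]) ** ea for i in range(n)]      # Proposition 4.2 (i)
        kt_u = [sum(kernel[i][j] * u[i] for i in range(n)) for j in range(n)]
        v = [(y[j] / kt_u[j]) ** eb for j in range(n)]     # Proposition 4.2 (ii)
    plan = [[u[i] * kernel[i][j] * v[j] for j in range(n)] for i in range(n)]
    row_sums = [sum(row) for row in plan]
    # Gradient with respect to h(x): gamma_a * (1 - T*1 ./ h(x)).
    grad = [gamma_a * (1.0 - row_sums[i] / h[i]) for i in range(n)]
    return plan, u, v, grad


M = [[float(abs(i - j)) for j in range(3)] for i in range(3)]
h = [1.0, 0.5, 0.2]  # total mass 1.7
y = [0.3, 0.3, 0.6]  # total mass 1.2
lam, gamma_a, gamma_b = 2.0, 1.0, 1.0
plan, u, v, grad = relaxed_transport(h, y, M, lam, gamma_a, gamma_b)
print(grad)
```

At the fixed point the scaling vectors also satisfy the characterization in Proposition 4.1, which gives a simple way to check that the iteration has converged.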

When restricted to normalized measures, the relaxed problem (8) approximates smoothed transport
(6). Figure 3a shows, for normalized $h(x)$ and $y$, the relative distance between the values of (8) and
(6).³ For $\lambda$ large enough, (8) converges to (6) as $\gamma_a$ and $\gamma_b$ increase.

(8) also retains two properties of smoothed transport (6). Figure 3b shows that, for normalized
outputs, the relaxed loss converges to the unregularized Wasserstein distance as $\lambda$, $\gamma_a$ and $\gamma_b$ increase.⁴
And Figure 3c shows that convergence of the iterations in Proposition 4.2 is nearly independent of the
dimension $K$ of the output space.

Note that, although the iteration suggested by Proposition 4.2 is observed empirically to converge (see
Figure 3c, for example), we have not proven a guarantee that it will do so.

3 In figures 3a-c, $h(x)$, $y$ and $M$ are generated as described in [18] section 5. In 3a-b, $h(x)$ and $y$ have
dimension 256. In 3c, convergence is defined as in [?]. Shaded regions are 95% intervals.

4 The unregularized Wasserstein distance was computed using FastEMD [?].
(a) Posterior predictions for images of digit 0. (b) Posterior predictions for images of digit 4. (Horizontal axis: $p$-th norm.)

Figure 4: MNIST example. Each curve shows the predicted probability for one digit, for models
trained with different $p$ values for the ground metric.


5 Statistical Properties of the Wasserstein loss 

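Let $S = ((x_1, y_1), \ldots, (x_N, y_N))$ be i.i.d. samples and $h_{\hat\theta}$ be the empirical risk minimizer

$$h_{\hat\theta} = \operatorname*{argmin}_{h_\theta \in \mathcal{H}} \left\{ \hat{\mathbb{E}}_S\left[W_p^p(h_\theta(\cdot|x), y)\right] = \frac{1}{N} \sum_{i=1}^{N} W_p^p(h_\theta(\cdot|x_i), y_i) \right\}.$$

Further assume $\mathcal{H} = s \circ \mathcal{H}^o$ is the composition of a softmax $s$ and a base hypothesis space $\mathcal{H}^o$ of
functions mapping into $\mathbb{R}^K$. The softmax layer outputs a prediction that lies in the simplex $\Delta^K$.

Theorem 5.1. For $p = 1$, and any $\delta > 0$, with probability at least $1 - \delta$, it holds that

$$\mathbb{E}\left[W_1^1(h_{\hat\theta}(\cdot|x), y)\right] \leq \inf_{h_\theta \in \mathcal{H}} \mathbb{E}\left[W_1^1(h_\theta(\cdot|x), y)\right] + 32 K C_M \mathfrak{R}_N(\mathcal{H}^o) + 2 C_M \sqrt{\frac{\log(1/\delta)}{2N}} \qquad (9)$$

with the constant $C_M = \max_{\kappa,\kappa'} M_{\kappa,\kappa'}$. $\mathfrak{R}_N(\mathcal{H}^o)$ is the Rademacher complexity [22] measuring
the complexity of the hypothesis space $\mathcal{H}^o$.

The Rademacher complexity $\mathfrak{R}_N(\mathcal{H}^o)$ for commonly used models like neural networks and kernel
machines [22] decays with the training set size. This theorem guarantees that the expected
Wasserstein loss of the empirical risk minimizer approaches the best achievable loss for $\mathcal{H}$.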
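To make the dependence on sample size concrete, the short sketch below evaluates only the confidence term $2 C_M \sqrt{\log(1/\delta)/(2N)}$ of the bound (9) for a few training set sizes. The function name and the values $C_M = 1$ and $\delta = 0.05$ are hypothetical choices for illustration; the Rademacher term is model dependent and is not computed here.

```python
import math


def confidence_term(c_m, n_samples, delta):
    """Evaluate 2 * C_M * sqrt(log(1/delta) / (2 N))."""
    return 2.0 * c_m * math.sqrt(math.log(1.0 / delta) / (2.0 * n_samples))


for n_samples in (100, 400, 1600):
    print(n_samples, confidence_term(c_m=1.0, n_samples=n_samples, delta=0.05))
```

The term shrinks at rate $N^{-1/2}$, so quadrupling the number of training examples halves it.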

As an important special case, minimizing the empirical risk with Wasserstein loss is also good for
multiclass classification. Let $y = \mathbb{1}_\kappa$ be the "one-hot" encoded label vector for the groundtruth class.

Proposition 5.2. In the multiclass classification setting, for $p = 1$ and any $\delta > 0$, with probability
at least $1 - \delta$, it holds that

$$\mathbb{E}_{x,\kappa}\left[d_\mathcal{K}(\kappa_{\hat\theta}(x), \kappa)\right] \leq \inf_{h_\theta \in \mathcal{H}} K\,\mathbb{E}\left[W_1^1(h_\theta(x), y)\right] + 32 K^2 C_M \mathfrak{R}_N(\mathcal{H}^o) + 2 C_M K \sqrt{\frac{\log(1/\delta)}{2N}} \qquad (10)$$

where the predictor is $\kappa_{\hat\theta}(x) = \operatorname{argmax}_\kappa h_{\hat\theta}(\kappa|x)$, with $h_{\hat\theta}$ being the empirical risk minimizer.

Note that instead of the classification error $\mathbb{E}_{x,\kappa}[\mathbb{1}\{\kappa_{\hat\theta}(x) \neq \kappa\}]$, we actually get a bound on the
expected semantic distance between the prediction and the groundtruth.

6 Empirical study 

6.1 Impact of the ground metric 

In this section, we show that the Wasserstein loss encourages smoothness with respect to an artificial
metric on the MNIST handwritten digit dataset. This is a multi-class classification problem with
output dimensions corresponding to the 10 digits, and we apply a ground metric
$d_p(\kappa, \kappa') = |\kappa - \kappa'|^p$, where $\kappa, \kappa' \in \{0, \ldots, 9\}$ and $p \in [0, \infty)$. This metric encourages the recognized digit to be
numerically close to the true one. We train a model independently for each value of $p$ and plot the
average predicted probabilities of the different digits on the test set in Figure 4.
(a) Original Flickr tags dataset. (b) Reduced-redundancy Flickr tags dataset. (Horizontal axis: K, the number of proposed tags.)

Figure 5: Top-K cost comparison of the proposed loss (Wasserstein) and the baseline (Divergence).


Note that as $p \to 0$, the metric approaches the $0$-$1$ metric $d_0(\kappa, \kappa') = \mathbb{1}_{\kappa \neq \kappa'}$, which treats all
incorrect digits as being equally unfavorable. In this case, as can be seen in the figure, the predicted
probability of the true digit goes to 1 while the probability for all other digits goes to 0. As $p$
increases, the predictions become more evenly distributed over the neighboring digits, converging
to a uniform distribution as $p \to \infty$.⁵
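The effect of $p$ on the loss can be seen in a small worked example. When the target is one-hot, every unit of predicted mass has to be moved to the true class, so the only valid transport plan has cost $\sum_\kappa h(\kappa)\, d_p(\kappa, \kappa_{\text{true}})$. The sketch below uses this fact; the helper names and the two hypothetical predictions (0.8 on the true digit 4, with the remaining 0.2 on either 5 or 9) are illustrative assumptions, and the metric is not rescaled as in the experiments.

```python
def ground_metric(p, n_classes=10):
    """d_p(k, k') = |k - k'|^p, with zero diagonal so that p = 0 gives the 0-1 metric."""
    return [[0.0 if k == k2 else float(abs(k - k2)) ** p for k2 in range(n_classes)]
            for k in range(n_classes)]


def loss_against_one_hot(h, true_class, M):
    """Exact Wasserstein loss for a one-hot target: all mass moves to true_class."""
    return sum(h[k] * M[k][true_class] for k in range(len(h)))


def prediction(mass_on, true_class=4, n_classes=10):
    """Hypothetical prediction: 0.8 on the true digit, 0.2 on one wrong digit."""
    h = [0.0] * n_classes
    h[true_class] = 0.8
    h[mass_on] = 0.2
    return h


near, far = prediction(5), prediction(9)
for p in (0.0, 1.0, 2.0):
    M = ground_metric(p)
    print(p, loss_against_one_hot(near, 4, M), loss_against_one_hot(far, 4, M))
```

With $p = 0$ both mistakes cost the same, while for larger $p$ the mistake on the distant digit is penalized increasingly more than the mistake on the neighboring digit.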


6.2 Flickr tag prediction 

We apply the Wasserstein loss to a real world multi-label learning problem, using the recently
released Yahoo/Flickr Creative Commons 100M dataset [23]. Our goal is tag prediction: we select
1000 descriptive tags along with two random sets of 10,000 images each, associated with these tags,
for training and testing. We derive a distance metric between tags by using word2vec [24] to
embed the tags as unit vectors, then taking their Euclidean distances. To extract image features we
use MatConvNet [25]. Note that the set of tags is highly redundant and often many semantically
equivalent or similar tags can apply to an image. The images are also partially tagged, as different
users may prefer different tags. We therefore measure the prediction performance by the top-K cost,
defined as $C_K = \frac{1}{K} \sum_{k=1}^{K} \min_j d_\mathcal{K}(\hat\kappa_k, \kappa_j)$, where $\{\kappa_j\}$ is the set of groundtruth tags, and $\{\hat\kappa_k\}$
are the tags with highest predicted probability. The standard AUC measure is also reported.
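The top-K cost can be computed directly from its definition. The sketch below is a minimal illustration: the function name `top_k_cost`, the four-tag distance matrix `D` and the predicted probabilities are hypothetical stand-ins, whereas the experiments use Euclidean distances between word2vec embeddings of 1000 tags.

```python
def top_k_cost(pred_probs, true_tags, D, k):
    """Average, over the k most probable tags, of the distance to the nearest true tag."""
    ranked = sorted(range(len(pred_probs)), key=lambda t: pred_probs[t], reverse=True)
    proposed = ranked[:k]
    return sum(min(D[t][g] for g in true_tags) for t in proposed) / k


# Hypothetical ground metric on four tags and one hypothetical prediction.
D = [[float(abs(i - j)) for j in range(4)] for i in range(4)]
pred_probs = [0.1, 0.4, 0.3, 0.2]

print(top_k_cost(pred_probs, [0], D, k=2))
print(top_k_cost(pred_probs, [0, 3], D, k=2))
print(top_k_cost(pred_probs, [1, 2], D, k=2))
```

A proposed tag that is not in the ground truth but lies close to some true tag contributes only a small cost, which is why this measure suits redundant, partially tagged data.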

We find that a linear combination of the Wasserstein loss $W_p^p$ and the standard multiclass logistic loss
KL yields the best prediction results. Specifically, we train a linear model by minimizing $W_p^p + \alpha\,\mathrm{KL}$
on the training set, where $\alpha$ controls the relative weight of KL. Note that KL taken alone is our
baseline in these experiments. Figure 5a shows the top-K cost on the test set for the combined loss
and the baseline KL loss. We additionally create a second dataset by removing redundant labels
from the original dataset: this simulates the potentially more difficult case in which a single user
tags each image, by selecting one tag to apply from amongst each cluster of applicable, semantically
similar tags. Figure 5b shows that performance for both algorithms decreases on the harder dataset,
while the combined Wasserstein loss continues to outperform the baseline.

In Figure 6, we show the effect on performance of varying the weight $\alpha$ on the KL loss. We observe
that the optimum of the top-K cost is achieved when the Wasserstein loss is weighted more heavily
than at the optimum of the AUC. This is consistent with a semantic smoothing effect of Wasserstein,
which during training will favor mispredictions that are semantically similar to the ground truth,
sometimes at the cost of lower AUC.⁷ We finally show two selected images from the test set in
Figure 7. These illustrate cases in which both algorithms make predictions that are semantically
relevant, despite overlapping very little with the ground truth. The image on the left shows errors
made by both algorithms. More examples can be found in the appendix.

5 To avoid numerical issues, we scale down the ground metric such that all of the distance values are in the 
interval [0,1). 

6 The dataset used here is available at http://cbcl.mit.edu/wasserstein.

7 The Wasserstein loss can achieve a similar trade-off by choosing the metric parameter $p$, as discussed in
Section 6.1. However, the relationship between $p$ and the smoothing behavior is complex and it can be simpler
to implement the trade-off by combining with the KL loss.
(b) Reduced-redundancy Flickr tags dataset.


Figure 6: Trade-off between semantic smoothness and maximum likelihood. 



(a) Flickr user tags: street, parade, dragon; our proposals: people, protest, parade; baseline
proposals: music, car, band.

(b) Flickr user tags: water, boat, reflection, sunshine; our proposals: water, river, lake, summer;
baseline proposals: river, water, club, nature.

Figure 7: Examples of images in the Flickr dataset. We show the groundtruth tags as well as
tags proposed by our algorithm and the baseline.


7 Conclusions and future work 

In this paper we have described a loss function for learning to predict a non-negative measure over a 
finite set, based on the Wasserstein distance. Although optimizing with respect to the exact
Wasserstein loss is computationally costly, an approximation based on entropic regularization is efficiently
computed. We described a learning algorithm based on this regularization and we proposed a novel 
extension of the regularized loss to unnormalized measures that preserves its efficiency. We also 
described a statistical learning bound for the loss. The Wasserstein loss can encourage smoothness 
of the predictions with respect to a chosen metric on the output space, and we demonstrated this 
property on a real-data tag prediction problem, showing improved performance over a baseline that 
doesn’t incorporate the metric. 

An interesting direction for future work may be to explore the connection between the Wasserstein 
loss and Markov random fields, as the latter are often used to encourage smoothness of predictions, 
via inference at prediction time. 
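A Relaxed transport

Equation (8) gives the relaxed transport objective as

$${}^{\lambda,\gamma_a,\gamma_b}W_{\mathrm{KL}}(h(\cdot|x), y(\cdot)) = \min_{T \in \mathbb{R}_+^{K \times K}} \langle T, M \rangle - \frac{1}{\lambda} H(T) + \gamma_a \widetilde{\mathrm{KL}}(T\mathbf{1} \,\|\, h(x)) + \gamma_b \widetilde{\mathrm{KL}}(T^\top\mathbf{1} \,\|\, y)$$

with $\widetilde{\mathrm{KL}}(w \| z) = w^\top \log(w \oslash z) - \mathbf{1}^\top w + \mathbf{1}^\top z$.

Proof of Proposition 4.1. The first order condition for $T^*$ optimizing (8) is

$$M_{ij} + \frac{1}{\lambda}\left(\log T^*_{ij} + 1\right) + \gamma_a \left(\log (T^*\mathbf{1} \oslash h(x))\right)_i + \gamma_b \left(\log ((T^*)^\top\mathbf{1} \oslash y)\right)_j = 0$$

$$\Rightarrow \log T^*_{ij} + \gamma_a\lambda \log (T^*\mathbf{1} \oslash h(x))_i + \gamma_b\lambda \log ((T^*)^\top\mathbf{1} \oslash y)_j = -\lambda M_{ij} - 1$$

$$\Rightarrow T^*_{ij}\, (T^*\mathbf{1} \oslash h(x))_i^{\gamma_a\lambda}\, ((T^*)^\top\mathbf{1} \oslash y)_j^{\gamma_b\lambda} = \exp(-\lambda M_{ij} - 1)$$

$$\Rightarrow T^*_{ij} = (h(x) \oslash T^*\mathbf{1})_i^{\gamma_a\lambda}\, (y \oslash (T^*)^\top\mathbf{1})_j^{\gamma_b\lambda} \exp(-\lambda M_{ij} - 1)$$

Hence $T^*$ (if it exists) is a diagonal scaling of $\mathbf{K} = \exp(-\lambda M - 1)$. □

Proof of Proposition 4.2. Let $u = (h(x) \oslash T^*\mathbf{1})^{\gamma_a\lambda}$ and $v = (y \oslash (T^*)^\top\mathbf{1})^{\gamma_b\lambda}$, so
$T^* = \mathrm{diag}(u)\,\mathbf{K}\,\mathrm{diag}(v)$. We have

$$T^*\mathbf{1} = \mathrm{diag}(u)\mathbf{K}v$$

$$\Rightarrow (T^*\mathbf{1})^{\gamma_a\lambda+1} = h(x)^{\gamma_a\lambda} \odot \mathbf{K}v$$

where we substituted the expression for $u$. Re-writing $T^*\mathbf{1}$,

$$(\mathrm{diag}(u)\mathbf{K}v)^{\gamma_a\lambda+1} = \mathrm{diag}(h(x)^{\gamma_a\lambda})\mathbf{K}v$$

$$\Rightarrow u^{\gamma_a\lambda+1} = h(x)^{\gamma_a\lambda} \odot (\mathbf{K}v)^{-\gamma_a\lambda}$$

$$\Rightarrow u = h(x)^{\frac{\gamma_a\lambda}{\gamma_a\lambda+1}} \odot (\mathbf{K}v)^{-\frac{\gamma_a\lambda}{\gamma_a\lambda+1}}.$$

A symmetric argument shows that $v = y^{\frac{\gamma_b\lambda}{\gamma_b\lambda+1}} \odot (\mathbf{K}^\top u)^{-\frac{\gamma_b\lambda}{\gamma_b\lambda+1}}$. □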

B Statistical Learning Bounds 

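We establish the proof of Theorem 5.1 in this section. For simpler notation, for a sequence
$S = ((x_1, y_1), \ldots, (x_N, y_N))$ of i.i.d. training samples, we denote the empirical risk $\hat{R}_S$ and risk $R$ as

$$\hat{R}_S(h_\theta) = \hat{\mathbb{E}}_S\left[W_p^p(h_\theta(\cdot|x), y(\cdot))\right], \qquad R(h_\theta) = \mathbb{E}\left[W_p^p(h_\theta(\cdot|x), y(\cdot))\right] \qquad (11)$$

Lemma B.1. Let $h_{\hat\theta}, h_{\theta^*} \in \mathcal{H}$ be the minimizers of the empirical risk $\hat{R}_S$ and expected risk $R$,
respectively. Then

$$R(h_{\hat\theta}) \leq R(h_{\theta^*}) + 2 \sup_{h \in \mathcal{H}} |R(h) - \hat{R}_S(h)|$$

Proof. By the optimality of $h_{\hat\theta}$ for $\hat{R}_S$,

$$R(h_{\hat\theta}) - R(h_{\theta^*}) = R(h_{\hat\theta}) - \hat{R}_S(h_{\hat\theta}) + \hat{R}_S(h_{\hat\theta}) - R(h_{\theta^*})$$

$$\leq R(h_{\hat\theta}) - \hat{R}_S(h_{\hat\theta}) + \hat{R}_S(h_{\theta^*}) - R(h_{\theta^*})$$

$$\leq 2 \sup_{h \in \mathcal{H}} |R(h) - \hat{R}_S(h)|$$

□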
Therefore, to bound the risk for $h_{\hat\theta}$, we need to establish uniform concentration bounds for the
Wasserstein loss. Towards that goal, we define a space of loss functions induced by the hypothesis
space $\mathcal{H}$ as

$$\mathcal{L} = \left\{\ell_\theta : (x, y) \mapsto W_p^p(h_\theta(\cdot|x), y(\cdot)) : h_\theta \in \mathcal{H}\right\} \qquad (12)$$

The uniform concentration will depend on the "complexity" of $\mathcal{L}$, which is measured by the
empirical Rademacher complexity defined below.

Definition B.2 (Rademacher Complexity [22]). Let $\mathcal{G}$ be a family of mappings from $\mathcal{Z}$ to $\mathbb{R}$, and
$S = (z_1, \ldots, z_N)$ a fixed sample from $\mathcal{Z}$. The empirical Rademacher complexity of $\mathcal{G}$ with respect
to $S$ is defined as

$$\hat{\mathfrak{R}}_S(\mathcal{G}) = \mathbb{E}_\sigma\left[\sup_{g \in \mathcal{G}} \frac{1}{N} \sum_{i=1}^{N} \sigma_i g(z_i)\right] \qquad (13)$$

where $\sigma = (\sigma_1, \ldots, \sigma_N)$, with $\sigma_i$'s independent uniform random variables taking values in
$\{+1, -1\}$. $\sigma_i$'s are called the Rademacher random variables. The Rademacher complexity is
defined by taking expectation with respect to the samples $S$,

$$\mathfrak{R}_N(\mathcal{G}) = \mathbb{E}_S\left[\hat{\mathfrak{R}}_S(\mathcal{G})\right] \qquad (14)$$
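Theorem B.3. For any $\delta > 0$, with probability at least $1 - \delta$, the following holds for all $\ell_\theta \in \mathcal{L}$,

$$\mathbb{E}[\ell_\theta] - \hat{\mathbb{E}}_S[\ell_\theta] \leq 2\mathfrak{R}_N(\mathcal{L}) + \sqrt{\frac{C_M^2 \log(1/\delta)}{2N}} \qquad (15)$$

with the constant $C_M = \max_{\kappa,\kappa'} M_{\kappa,\kappa'}$.

By the definition of $\mathcal{L}$, $\mathbb{E}[\ell_\theta] = R(h_\theta)$ and $\hat{\mathbb{E}}_S[\ell_\theta] = \hat{R}_S(h_\theta)$. Therefore, this theorem provides a
uniform control for the deviation of the empirical risk from the risk.

Theorem B.4 (McDiarmid's Inequality). Let $S = \{X_1, \ldots, X_N\} \subset \mathcal{Z}$ be $N$ i.i.d. random
variables. Assume there exists $C > 0$ such that $f : \mathcal{Z}^N \to \mathbb{R}$ satisfies the following stability condition

$$|f(x_1, \ldots, x_i, \ldots, x_N) - f(x_1, \ldots, x_i', \ldots, x_N)| \leq C \qquad (16)$$

for all $i = 1, \ldots, N$ and any $x_1, \ldots, x_N, x_i' \in \mathcal{Z}$. Then for any $\varepsilon > 0$, denoting $f(X_1, \ldots, X_N)$
by $f(S)$, it holds that

$$\mathbb{P}\left(f(S) - \mathbb{E}[f(S)] \geq \varepsilon\right) \leq \exp\left(-\frac{2\varepsilon^2}{N C^2}\right) \qquad (17)$$

Lemma B.5. Let the constant $C_M = \max_{\kappa,\kappa'} M_{\kappa,\kappa'}$, then $0 \leq W_p^p(\cdot, \cdot) \leq C_M$.

Proof. For any $h(\cdot|x)$ and $y(\cdot)$, let $T^* \in \Pi(h(x), y)$ be the optimal transport plan that solves (3),
then

$$W_p^p(h(x), y) = \langle T^*, M \rangle \leq C_M \sum_{\kappa,\kappa'} T^*_{\kappa,\kappa'} = C_M$$

□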


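Proof of Theorem B.3. For any $\ell_\theta \in \mathcal{L}$, note the empirical expectation is the empirical risk of the
corresponding $h_\theta$:

$$\hat{\mathbb{E}}_S[\ell_\theta] = \frac{1}{N} \sum_{i=1}^{N} \ell_\theta(x_i, y_i) = \frac{1}{N} \sum_{i=1}^{N} W_p^p(h_\theta(\cdot|x_i), y_i(\cdot)) = \hat{R}_S(h_\theta)$$

Similarly, $\mathbb{E}[\ell_\theta] = R(h_\theta)$. Let

$$\Phi(S) = \sup_{\ell \in \mathcal{L}} \mathbb{E}[\ell] - \hat{\mathbb{E}}_S[\ell] \qquad (18)$$

Let $S'$ be $S$ with the $i$-th sample replaced by $(x_i', y_i')$. By Lemma B.5, it holds that

$$\Phi(S) - \Phi(S') \leq \sup_{\ell \in \mathcal{L}} \hat{\mathbb{E}}_{S'}[\ell] - \hat{\mathbb{E}}_S[\ell] = \sup_{h_\theta \in \mathcal{H}} \frac{W_p^p(h_\theta(x_i'), y_i') - W_p^p(h_\theta(x_i), y_i)}{N} \leq \frac{C_M}{N}$$

Similarly, we can show $\Phi(S') - \Phi(S) \leq C_M/N$, thus $|\Phi(S') - \Phi(S)| \leq C_M/N$. By Theorem B.4,
for any $\delta > 0$, with probability at least $1 - \delta$, it holds that

$$\Phi(S) \leq \mathbb{E}[\Phi(S)] + \sqrt{\frac{C_M^2 \log(1/\delta)}{2N}} \qquad (19)$$

To bound $\mathbb{E}[\Phi(S)]$, by Jensen's inequality,

$$\mathbb{E}_S[\Phi(S)] = \mathbb{E}_S\left[\sup_{\ell \in \mathcal{L}} \mathbb{E}[\ell] - \hat{\mathbb{E}}_S[\ell]\right] = \mathbb{E}_S\left[\sup_{\ell \in \mathcal{L}} \mathbb{E}_{S'}\left[\hat{\mathbb{E}}_{S'}[\ell] - \hat{\mathbb{E}}_S[\ell]\right]\right] \leq \mathbb{E}_{S,S'}\left[\sup_{\ell \in \mathcal{L}} \hat{\mathbb{E}}_{S'}[\ell] - \hat{\mathbb{E}}_S[\ell]\right]$$

Here $S'$ is another sequence of i.i.d. samples, usually called ghost samples, that is only used for
analysis. Now we introduce the Rademacher variables $\sigma_i$. Since the roles of $S$ and $S'$ are completely
symmetric, it follows

$$\mathbb{E}_S[\Phi(S)] \leq \mathbb{E}_{S,S',\sigma}\left[\sup_{\ell \in \mathcal{L}} \frac{1}{N} \sum_{i=1}^{N} \sigma_i\left(\ell(x_i', y_i') - \ell(x_i, y_i)\right)\right]$$

$$\leq \mathbb{E}_{S',\sigma}\left[\sup_{\ell \in \mathcal{L}} \frac{1}{N} \sum_{i=1}^{N} \sigma_i \ell(x_i', y_i')\right] + \mathbb{E}_{S,\sigma}\left[\sup_{\ell \in \mathcal{L}} \frac{1}{N} \sum_{i=1}^{N} -\sigma_i \ell(x_i, y_i)\right]$$

$$= \mathbb{E}_{S'}\left[\hat{\mathfrak{R}}_{S'}(\mathcal{L})\right] + \mathbb{E}_S\left[\hat{\mathfrak{R}}_S(\mathcal{L})\right] = 2\mathfrak{R}_N(\mathcal{L})$$

The conclusion follows by combining (18) and (19). □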

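To finish the proof of Theorem 5.1, we combine Lemma B.1 and Theorem B.3 and relate $\mathfrak{R}_N(\mathcal{L})$
to $\mathfrak{R}_N(\mathcal{H})$ via the following generalized Talagrand's lemma [26].

Lemma B.6. Let $\mathcal{F}$ be a class of real functions, and $\mathcal{H} \subset \mathcal{F} = \mathcal{F}_1 \times \ldots \times \mathcal{F}_K$ be a $K$-valued
function class. If $m : \mathbb{R}^K \to \mathbb{R}$ is a $L_m$-Lipschitz function and $m(0) = 0$, then
$\hat{\mathfrak{R}}_S(m \circ \mathcal{H}) \leq 2 L_m \sum_{k=1}^{K} \hat{\mathfrak{R}}_S(\mathcal{F}_k)$.

Theorem B.7 (Theorem 6.15 of [?]). Let $\mu$ and $\nu$ be two probability measures on a Polish space
$(\mathcal{K}, d_\mathcal{K})$. Let $p \in [1, \infty)$ and $\kappa_0 \in \mathcal{K}$. Then

$$W_p(\mu, \nu) \leq 2^{1/p'} \left(\int_\mathcal{K} d_\mathcal{K}(\kappa_0, \kappa)^p \, d|\mu - \nu|(\kappa)\right)^{1/p}, \qquad \frac{1}{p} + \frac{1}{p'} = 1 \qquad (20)$$

Corollary B.8. The Wasserstein loss is Lipschitz continuous in the sense that for any $h_\theta \in \mathcal{H}$, and
any $(x, y) \in \mathcal{X} \times \mathcal{Y}$,

$$W_p^p(h_\theta(\cdot|x), y) \leq 2^{p-1} C_M \sum_{\kappa \in \mathcal{K}} |h_\theta(\kappa|x) - y(\kappa)| \qquad (21)$$

In particular, when $p = 1$, we have

$$W_1^1(h_\theta(\cdot|x), y) \leq C_M \sum_{\kappa \in \mathcal{K}} |h_\theta(\kappa|x) - y(\kappa)| \qquad (22)$$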

We cannot apply Lemma B.6 directly to the Wasserstein loss class, because the Wasserstein loss is
only defined on probability distributions, so 0 is not a valid input. To get around this problem, we
assume the hypothesis space $\mathcal{H}$ used in learning is of the form

$$\mathcal{H} = \{s \circ h^o : h^o \in \mathcal{H}^o\} \qquad (23)$$

where $\mathcal{H}^o$ is a function class that maps into $\mathbb{R}^K$, and $s$ is the softmax function defined as
$s(o) = (s_1(o), \ldots, s_K(o))$, with

$$s_k(o) = \frac{e^{o_k}}{\sum_j e^{o_j}}, \qquad k = 1, \ldots, K \qquad (24)$$

The softmax layer produces a valid probability distribution from arbitrary input, and this is consistent
with commonly used models such as Logistic Regression and Neural Networks. By working with
the log of the groundtruth labels, we can also add a softmax layer to the labels.
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Lemma B.9 (Proposition 2 of |27l)* The Wasserstein distances W p (-, •) are metrics on the space of 
probability distributions of JC, for all 1 < p < oo. 

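Proposition B.10. The map $\iota : \mathbb{R}^K \times \mathbb{R}^K \to \mathbb{R}$ defined by $\iota(y, y') = W_1^1(s(y), s(y'))$ satisfies

$$|\iota(y, y') - \iota(\bar{y}, \bar{y}')| \le 4 C_M \|(y, y') - (\bar{y}, \bar{y}')\|_2 \qquad (25)$$

for any $(y, y'), (\bar{y}, \bar{y}') \in \mathbb{R}^K \times \mathbb{R}^K$. And $\iota(0, 0) = 0$.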


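Proof. For any $(y, y'), (\bar{y}, \bar{y}') \in \mathbb{R}^K \times \mathbb{R}^K$, by Lemma B.9 we can use the triangle inequality on the
Wasserstein loss,

$$|\iota(y, y') - \iota(\bar{y}, \bar{y}')| = |\iota(y, y') - \iota(\bar{y}, y') + \iota(\bar{y}, y') - \iota(\bar{y}, \bar{y}')| \le \iota(y, \bar{y}) + \iota(y', \bar{y}')$$

Following Corollary B.8, it continues as

$$|\iota(y, y') - \iota(\bar{y}, \bar{y}')| \le C_M \left( \|s(y) - s(\bar{y})\|_1 + \|s(y') - s(\bar{y}')\|_1 \right) \qquad (26)$$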

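Note that for each $k = 1, \dots, K$, the gradient $\nabla_y s_k$ satisfies

$$\|\nabla_y s_k\|_2 = \left\| \left( \frac{\partial s_k}{\partial y_j} \right)_{j=1}^{K} \right\|_2 = \left\| \left( \delta_{kj} s_k - s_k s_j \right)_{j=1}^{K} \right\|_2 = \sqrt{ s_k^2 \sum_{j=1}^{K} s_j^2 + s_k^2 (1 - 2 s_k) } \qquad (27)$$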


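By the mean value theorem, for each $k$ there exists $\alpha_k \in [0, 1]$ such that, for $y_{\alpha_k} = \alpha_k y + (1 - \alpha_k) \bar{y}$, it holds that

$$\|s(y) - s(\bar{y})\|_1 = \sum_{k=1}^{K} \left| \left\langle \nabla_y s_k |_{y = y_{\alpha_k}}, \, y - \bar{y} \right\rangle \right| \le \sum_{k=1}^{K} \left\| \nabla_y s_k |_{y = y_{\alpha_k}} \right\|_2 \|y - \bar{y}\|_2 \le 2 \|y - \bar{y}\|_2$$

because by (27), and the fact that $\sqrt{\sum_j s_j^2} \le \sum_j s_j = 1$ and $\sqrt{a + b} \le \sqrt{a} + \sqrt{b}$ for $a, b \ge 0$, it
holds that

$$\sum_{k=1}^{K} \|\nabla_y s_k\|_2 = \sum_{k : s_k \le 1/2} \|\nabla_y s_k\|_2 + \sum_{k : s_k > 1/2} \|\nabla_y s_k\|_2$$

$$\le \sum_{k : s_k \le 1/2} \left( s_k + s_k \sqrt{1 - 2 s_k} \right) + \sum_{k : s_k > 1/2} s_k \le \sum_{k=1}^{K} 2 s_k = 2$$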

Similarly, we have $\|s(y') - s(\bar{y}')\|_1 \le 2 \|y' - \bar{y}'\|_2$, so from (26), we know

$$|\iota(y, y') - \iota(\bar{y}, \bar{y}')| \le 2 C_M \left( \|y - \bar{y}\|_2 + \|y' - \bar{y}'\|_2 \right) \le 2\sqrt{2} \, C_M \left( \|y - \bar{y}\|_2^2 + \|y' - \bar{y}'\|_2^2 \right)^{1/2}$$

then (25) follows immediately. The second conclusion follows trivially as $s$ maps the zero vector to
a uniform distribution. □
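The two facts used in this proof can be checked numerically. The sketch below is an illustration only: it computes the gradient norms of the softmax outputs directly from the Jacobian entries $s_k(\delta_{kj} - s_j)$, compares them with the closed form (27), and evaluates both sides of $\|s(y) - s(\bar{y})\|_1 \le 2\|y - \bar{y}\|_2$ on random inputs. All function names, the dimension 5, and the sampling range are assumptions made for the example.

```python
import math
import random

def softmax(o):
    m = max(o)  # max-shift for numerical stability
    exps = [math.exp(v - m) for v in o]
    total = sum(exps)
    return [e / total for e in exps]

def gradient_norms(y):
    """Return s(y) and the Euclidean norm of grad_y s_k for every k."""
    s = softmax(y)
    n = len(s)
    norms = []
    for k in range(n):
        # Jacobian row k: d s_k / d y_j = s_k * (delta_kj - s_j)
        row = [s[k] * ((1.0 if j == k else 0.0) - s[j]) for j in range(n)]
        norms.append(math.sqrt(sum(v * v for v in row)))
    return s, norms

def closed_form_norm(s, k):
    """Right-hand side of (27); clamped at 0 to guard against rounding."""
    sq = sum(v * v for v in s)
    return math.sqrt(max(0.0, s[k] ** 2 * sq + s[k] ** 2 * (1.0 - 2.0 * s[k])))

def l1_distance(a, b):
    return sum(abs(p - q) for p, q in zip(a, b))

def l2_distance(a, b):
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))

rng = random.Random(0)
y = [rng.uniform(-3.0, 3.0) for _ in range(5)]
y_bar = [rng.uniform(-3.0, 3.0) for _ in range(5)]

s, norms = gradient_norms(y)
total_norm = sum(norms)                          # bounded by 2 in the proof
lhs = l1_distance(softmax(y), softmax(y_bar))    # ||s(y) - s(y_bar)||_1
rhs = 2.0 * l2_distance(y, y_bar)                # 2 ||y - y_bar||_2
```

A numerical check of this kind does not replace the proof; it only confirms that the reconstructed formulas are mutually consistent on sampled points.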


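Proof of Theorem 5.1. Consider the loss function space preceded with a softmax layer

$$\mathcal{L} = \{ \iota_\theta : (x, y) \mapsto W_1^1(s(h^o_\theta(x)), s(y)) : h^o_\theta \in \mathcal{H}^o \}$$

We apply Lemma B.6 to the $4 C_M$-Lipschitz continuous function $\iota$ in Proposition B.10 and the
function space

$$\underbrace{\mathcal{H}^o \times \dots \times \mathcal{H}^o}_{K \text{ copies}} \times \underbrace{\mathcal{I} \times \dots \times \mathcal{I}}_{K \text{ copies}}$$

with $\mathcal{I}$ a singleton function space with only the identity map. It holds that

$$\hat{\mathfrak{R}}_S(\mathcal{L}) \le 8 C_M \left( K \hat{\mathfrak{R}}_S(\mathcal{H}^o) + K \hat{\mathfrak{R}}_S(\mathcal{I}) \right) = 8 K C_M \hat{\mathfrak{R}}_S(\mathcal{H}^o) \qquad (28)$$

because for the identity map, and a sample $S = (y_1, \dots, y_N)$, we can calculate

$$\hat{\mathfrak{R}}_S(\mathcal{I}) = \mathbb{E}_\sigma \left[ \sup_{f \in \mathcal{I}} \frac{1}{N} \sum_{i=1}^{N} \sigma_i f(y_i) \right] = \mathbb{E}_\sigma \left[ \frac{1}{N} \sum_{i=1}^{N} \sigma_i y_i \right] = 0$$

The conclusion of the theorem follows by combining (28) with Theorem B.3 and Lemma B.1. □

C Connection with multiclass classification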


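Proof of Proposition [?]. Given that the label is a "one-hot" vector $y = e_\kappa$, the set of transport plans
degenerates. Specifically, the constraint $T^\top \mathbf{1} = e_\kappa$ means that only the $\kappa$-th column of $T$ can
be non-zero. Furthermore, the constraint $T \mathbf{1} = h_{\hat\theta}(\cdot|x)$ ensures that the $\kappa$-th column of $T$ actually
equals $h_{\hat\theta}(\cdot|x)$. In other words, the set $\Pi(h_{\hat\theta}(\cdot|x), e_\kappa)$ contains only one feasible transport plan, so
the Wasserstein loss can be computed directly as

$$W_p^p(h_{\hat\theta}(\cdot|x), e_\kappa) = \sum_{\kappa' \in \mathcal{K}} M_{\kappa', \kappa} \, h_{\hat\theta}(\kappa'|x) = \sum_{\kappa' \in \mathcal{K}} d_{\mathcal{K}}^p(\kappa', \kappa) \, h_{\hat\theta}(\kappa'|x)$$

Now let $\hat\kappa = \arg\max_\kappa h_{\hat\theta}(\kappa|x)$ be the prediction. We have

$$h_{\hat\theta}(\hat\kappa|x) = 1 - \sum_{\kappa \ne \hat\kappa} h_{\hat\theta}(\kappa|x) \ge 1 - \sum_{\kappa \ne \hat\kappa} h_{\hat\theta}(\hat\kappa|x) = 1 - (K - 1) h_{\hat\theta}(\hat\kappa|x)$$

Therefore, $h_{\hat\theta}(\hat\kappa|x) \ge 1/K$, so

$$W_p^p(h_{\hat\theta}(\cdot|x), e_\kappa) \ge d_{\mathcal{K}}^p(\hat\kappa, \kappa) \, h_{\hat\theta}(\hat\kappa|x) \ge d_{\mathcal{K}}^p(\hat\kappa, \kappa) / K$$

The conclusion follows by applying Theorem 5.1 with $p = 1$. □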
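To make the one-hot case concrete, here is a small Python sketch of the closed-form loss and of the lower bound $d_{\mathcal{K}}^p(\hat\kappa, \kappa)/K$ derived above. The three-class example, the ground metric $|i - j|$ with $p = 1$, and the function name `one_hot_wasserstein` are hypothetical choices for illustration.

```python
def one_hot_wasserstein(h, M, kappa):
    """Loss against a one-hot label e_kappa: sum over k' of M[k'][kappa] * h[k']."""
    return sum(M[kp][kappa] * h[kp] for kp in range(len(h)))

# Hypothetical 3-class problem with ground metric d(i, j) = |i - j| and p = 1.
M = [[abs(i - j) for j in range(3)] for i in range(3)]
h = [0.2, 0.5, 0.3]   # predicted distribution
kappa = 2             # true class

loss = one_hot_wasserstein(h, M, kappa)            # 2*0.2 + 1*0.5 + 0*0.3
k_hat = max(range(len(h)), key=lambda k: h[k])     # argmax prediction
lower_bound = M[k_hat][kappa] / len(h)             # d(k_hat, kappa) / K
```

Here the loss evaluates to 0.9, the prediction is class 1, and the bound is 1/3, in line with the inequality in the proof.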


D Algorithmic Details of Learning with a Wasserstein Loss 


In Section 5 we describe the statistical generalization properties of learning with a Wasserstein loss
function via empirical risk minimization on a general space of classifiers $\mathcal{H}$. In all the empirical
studies presented in the paper, we use the space of linear logistic regression classifiers, defined by

$$\mathcal{H} = \left\{ h_\theta(x) = \left( \frac{\exp(\theta_k^\top x)}{\sum_{j=1}^{K} \exp(\theta_j^\top x)} \right)_{k=1}^{K} : \theta_k \in \mathbb{R}^D, \ k = 1, \dots, K \right\}$$

We use stochastic gradient descent with a mini-batch size of 100 samples to optimize the empirical
risk, with a standard regularizer $0.0005 \sum_{k=1}^{K} \|\theta_k\|_2^2$ on the weights. The algorithm is described
in Algorithm 2, where WASSERSTEIN is a sub-routine that computes the Wasserstein loss and its
subgradient via the dual solution as described in Algorithm 1. We always run the gradient descent
for a fixed number of 100,000 iterations for training.


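Algorithm 2 SGD Learning of Linear Logistic Model with Wasserstein Loss

Init $\theta^1$ randomly.
for $t = 1, \dots, T$ do

    Sample mini-batch $\mathcal{D}^t = (x_1, y_1), \dots, (x_n, y_n)$ from the training set.

    Compute Wasserstein subgradient $\partial W_p^p / \partial h_\theta |_{\theta^t} \leftarrow$ WASSERSTEIN($\mathcal{D}^t$, $h_{\theta^t}(\cdot)$).
    Compute parameter subgradient $\partial W_p^p / \partial \theta |_{\theta^t} = (\partial h_\theta / \partial \theta)(\partial W_p^p / \partial h_\theta) |_{\theta^t}$.
    Update parameter $\theta^{t+1} \leftarrow \theta^t - \eta_t \, \partial W_p^p / \partial \theta |_{\theta^t}$.

end for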


Note that the same training algorithm can easily be extended from training a linear logistic regression
model to a multi-layer neural network model, by cascading the chain rule in the subgradient
computation.
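The structure of Algorithm 2 can be sketched in a few lines of Python. This is a simplified illustration under stated assumptions, not the implementation used in the paper: it handles only one-hot labels, so the WASSERSTEIN subroutine is replaced by the closed form from Appendix C (the gradient with respect to the prediction is the column $M_{\cdot, \kappa}$), it uses a constant step size instead of a schedule $\eta_t$, and the default values of `steps`, `batch`, and `lr` are arbitrary. Only the regularizer weight 0.0005 is taken from the text.

```python
import math
import random

def softmax(o):
    m = max(o)
    exps = [math.exp(v - m) for v in o]
    total = sum(exps)
    return [e / total for e in exps]

def predict(theta, x):
    """Linear logistic model: softmax of the scores theta_k . x."""
    return softmax([sum(w * xi for w, xi in zip(row, x)) for row in theta])

def mean_loss(theta, data, M):
    """Average one-hot Wasserstein loss sum_k M[k][kappa] * h_k over the data."""
    total = 0.0
    for x, kappa in data:
        h = predict(theta, x)
        total += sum(M[k][kappa] * h[k] for k in range(len(h)))
    return total / len(data)

def sgd(data, M, dim, steps=200, batch=10, lr=0.5, reg=0.0005, seed=0):
    rng = random.Random(seed)
    n_classes = len(M)
    # Init theta randomly (small Gaussian values, an assumption of this sketch).
    theta = [[rng.gauss(0.0, 0.01) for _ in range(dim)] for _ in range(n_classes)]
    for _ in range(steps):
        minibatch = [rng.choice(data) for _ in range(batch)]
        # Gradient of the regularizer reg * sum_k ||theta_k||^2.
        grad = [[2.0 * reg * w for w in row] for row in theta]
        for x, kappa in minibatch:
            h = predict(theta, x)
            g = [M[k][kappa] for k in range(n_classes)]   # dW/dh for a one-hot label
            g_bar = sum(gk * hk for gk, hk in zip(g, h))
            for j in range(n_classes):
                d_score = h[j] * (g[j] - g_bar)           # chain rule through the softmax
                for d in range(dim):
                    grad[j][d] += d_score * x[d] / batch
        # Parameter update with a constant step size.
        theta = [[w - lr * gw for w, gw in zip(row, grow)]
                 for row, grow in zip(theta, grad)]
    return theta
```

For general (non one-hot) labels, the line computing `g` would be replaced by a call to a routine that returns the subgradient from the dual solution, as Algorithm 2 prescribes.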

E Empirical study 

E.1 Noisy label example

We simulate the phenomenon of label noise arising from confusion of semantically similar classes
as follows. Consider a multiclass classification problem, in which the labels correspond to the
vertices on a $D \times D$ lattice on the 2D plane. The Euclidean distance in $\mathbb{R}^2$ is used to measure the
semantic similarity between labels. The observations for each category are samples from an isotropic
Gaussian distribution centered at the corresponding vertex. Given a noise level $t$, we choose with
probability $t$ to flip the label for each training sample to one of the neighboring categories,⁸ chosen
uniformly at random. Figure 8 shows the training set for a $3 \times 3$ lattice with noise levels $t = 0.1$ and
$t = 0.5$, respectively.

Figure 8: Illustration of training samples on a $3 \times 3$ lattice with different noise levels. (a) Noise level 0.1. (b) Noise level 0.5.

Figure 2 is generated as follows. We repeat the experiment 10 times for each noise level $t = 0.1, 0.2, \dots, 0.9$ and
each lattice size $D = 3, 4, \dots, 7$. We train a multiclass linear logistic regression classifier (as described in Section D
of the Appendix) using either the standard KL-divergence loss⁹ or the proposed Wasserstein loss.¹⁰
The performance is measured by the mean Euclidean distance in the plane between the predicted
class and the true class, on the test set. Figure 2 compares the performance of the two loss functions.
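The data-generating process for this experiment can be sketched as follows. This is a hedged reconstruction: the standard deviation `sigma` of the Gaussian, the number of samples per class, and all function names are assumptions, since the text does not specify them. Neighbors are taken to be the lattice vertices at Euclidean distance 1, following footnote 8.

```python
import random

def lattice_neighbors(i, j, size):
    """Vertices connected to (i, j) on a size x size lattice (Euclidean distance 1)."""
    candidates = [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
    return [(a, b) for a, b in candidates if 0 <= a < size and 0 <= b < size]

def sample_noisy_lattice(size, n_per_class, noise, sigma=0.1, seed=0):
    """Return (observation, true_label, noisy_label) triples.

    Each observation is drawn from an isotropic Gaussian centered at its vertex.
    With probability `noise` the label is flipped to a uniformly chosen neighbor.
    """
    rng = random.Random(seed)
    samples = []
    for i in range(size):
        for j in range(size):
            for _ in range(n_per_class):
                x = (rng.gauss(i, sigma), rng.gauss(j, sigma))
                label = (i, j)
                if rng.random() < noise:
                    label = rng.choice(lattice_neighbors(i, j, size))
                samples.append((x, (i, j), label))
    return samples

# Training set for a 3 x 3 lattice at noise level 0.5.
train = sample_noisy_lattice(3, 50, 0.5)
```

Because flips only go to adjacent vertices, a noisy label is never more than distance 1 from the true class, which is the structure a metric-aware loss can exploit.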


E.2 Full figure for the MNIST example 


The full version of Figure 4 from Section 6.1 is shown in Figure 9.


E.3 Details of the Flickr tag prediction experiment 

From the tags in the Yahoo Flickr Creative Commons dataset, we filtered out those not occurring
in the WordNet¹¹ database, as well as those whose dominant lexical category was "noun.location" or
"noun.time." We also filtered out by hand nouns referring to geographical location or nationality,
proper nouns, numbers, photography-specific vocabulary, and several words not generally descriptive
of visual content (such as "annual" and "demo"). From the remainder, the 1000 most frequently
occurring tags were used.

We list some of the 1000 selected tags here. The 50 most frequently occurring tags: travel, square,
wedding, art, flower, music, nature, party, beach, family, people, food, tree, summer, water, concert,
winter, sky, snow, street, portrait, architecture, car, live, trip, friend, cat, sign, garden, mountain,
bird, sport, light, museum, animal, rock, show, spring, dog, film, blue, green, road, girl, event, red,
fun, building, new, cloud. ... and the 50 least frequent tags: arboretum, chick, sightseeing, vineyard,
animalia, burlesque, key, flat, whale, swiss, giraffe, floor, peak, contemporary, scooter, society, actor,
tomb, fabric, gala, coral, sleeping, lizard, performer, album, body, crew, bathroom, bed, cricket,
piano, base, poetry, master, renovation, step, ghost, freight, champion, cartoon, jumping, crochet,
gaming, shooting, animation, carving, rocket, infant, drift, hope.

8 Connected vertices on the lattice are considered neighbors, and the Euclidean distance between neighbors
is set to 1.

9 This corresponds to maximum likelihood estimation of the logistic regression model.

10 In this special case, this corresponds to weighted maximum likelihood estimation, cf. Section C.

11 http://wordnet.princeton.edu

Figure 9: Each curve is the predicted probability for a target digit from models trained with different
p values for the ground metric. (a) Posterior prediction for images of digit 0. (b) Posterior prediction
for images of digit 4. The horizontal axis in both panels is the p-th norm.

The complete features and labels can also be downloaded from the project website.¹² We train a
multiclass linear logistic regression model with a linear combination of the Wasserstein loss and
the KL divergence-based loss. The Wasserstein loss between the prediction and the normalized
groundtruth is computed as described in Algorithm 1, using 10 iterations of the Sinkhorn-Knopp
algorithm. Based on inspection of the ground metric matrix, we use the p-norm with $p = 13$, and set
$\lambda = 50$. This ensures that the matrix $K$ is reasonably sparse, enforcing semantic smoothness only in
each local neighborhood. Stochastic gradient descent with a mini-batch size of 100 and momentum
0.7 is run for 100,000 iterations to optimize the objective function on the training set. The baseline
is trained under the same setting, using only the KL loss function.
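For readers unfamiliar with Sinkhorn-Knopp scaling, the following Python sketch shows one common way to compute an entropy-regularized transport plan, the associated loss, and a gradient with respect to the prediction. Algorithm 1 is not reproduced in this section, so the kernel form $\exp(-\lambda M - 1)$, the gradient expression based on $\log u$, and every name below are assumptions of this sketch rather than a transcription of the paper's routine. The prediction `h` must be strictly positive, which holds for softmax outputs.

```python
import math

def sinkhorn_wasserstein(h, y, M, lam=10.0, n_iter=10):
    """Entropy-regularized transport between prediction h and target y.

    Returns (loss, grad, plan), where grad approximates d loss / d h and is
    centered so that its entries sum to zero.
    """
    n = len(h)
    kern = [[math.exp(-lam * M[i][j] - 1.0) for j in range(n)] for i in range(n)]
    u = [1.0] * n
    for _ in range(n_iter):
        # v = y ./ (K^T u), then u = h ./ (K v)
        v = [y[j] / sum(kern[i][j] * u[i] for i in range(n)) for j in range(n)]
        u = [h[i] / sum(kern[i][j] * v[j] for j in range(n)) for i in range(n)]
    # Final v update, so the column marginals of the plan match y.
    v = [y[j] / sum(kern[i][j] * u[i] for i in range(n)) for j in range(n)]
    plan = [[u[i] * kern[i][j] * v[j] for j in range(n)] for i in range(n)]
    loss = sum(plan[i][j] * M[i][j] for i in range(n) for j in range(n))
    log_u = [math.log(ui) for ui in u]
    mean_log_u = sum(log_u) / n
    grad = [(lu - mean_log_u) / lam for lu in log_u]
    return loss, grad, plan

# Hypothetical 3-class example with ground metric |i - j|.
M = [[abs(i - j) for j in range(3)] for i in range(3)]
loss, grad, plan = sinkhorn_wasserstein([0.7, 0.2, 0.1], [0.1, 0.2, 0.7], M)
```

Larger values of `lam` make the kernel entries for distant classes vanish faster, which is the sparsity effect the paragraph above attributes to $\lambda = 50$; in floating point, very large `lam` can also cause underflow, so practical implementations often work in the log domain.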

To create the dataset with reduced redundancy, for each image in the training set, we compute the 
pairwise semantic distance for the groundtruth tags, and cluster them into “equivalent” tag-sets with 
a threshold of semantic distance 1.3. Within each tag-set, one random tag is selected. 

Figure 10 shows more test images and predictions randomly picked from the test set.

12 http://cbcl.mit.edu/wasserstein/

(a) Flickr user tags: zoo, run, mark; our proposals: running, summer, fun; baseline proposals:
running, country, lake.

(b) Flickr user tags: travel, architecture, tourism; our proposals: sky, roof, building; baseline
proposals: art, sky, beach.

(c) Flickr user tags: spring, race, training; our proposals: road, bike, trail; baseline proposals: dog,
surf, bike.

(d) Flickr user tags: family, trip, house; our proposals: family, girl, green; baseline proposals: woman,
tree, family.

(e) Flickr user tags: education, weather, cow, agriculture; our proposals: girl, people, animal, play;
baseline proposals: concert, statue, pretty, girl.

(f) Flickr user tags: garden, table, gardening; our proposals: garden, spring, plant; baseline proposals:
garden, decoration, plant.

(g) Flickr user tags: nature, bird, rescue; our proposals: bird, nature, wildlife; baseline proposals:
nature, bird, baby.

Figure 10: Examples of images in the Flickr dataset. We show the groundtruth tags as well as the
tags proposed by our algorithm and the baseline.